{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ObJAX Tutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dependencies\n",
    "\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jn\n",
    "import objax\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST dataset\n",
    "\n",
    "import tensorflow.keras as keras\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "# ObJAX is like Torch, and puts the channel BEFORE the height and width.\n",
    "# The data starts out as (dataset_size, height=28, width=28)\n",
    "# We're going to make it (dataset_size, channels=1, height=28, width=28)\n",
    "x_train = x_train[:,None,:,:]/255.0\n",
    "x_test = x_test[:,None,:,:]/255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f1dcc03b4e0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy37P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKcZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0LtzdFxHBrDRQMam+D2pdEb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Check that we've loaded the data properly.\n",
    "# Print out the first channel of the first image in the test set\n",
    "plt.imshow(x_test[0,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Sequential)[0](Conv2D).b       32 (32, 1, 1)\n",
      "(Sequential)[0](Conv2D).w      800 (5, 5, 1, 32)\n",
      "(Sequential)[3](Conv2D).b       32 (32, 1, 1)\n",
      "(Sequential)[3](Conv2D).w    25600 (5, 5, 32, 32)\n",
      "(Sequential)[6](Conv2D).b       64 (64, 1, 1)\n",
      "(Sequential)[6](Conv2D).w    51200 (5, 5, 32, 64)\n",
      "(Sequential)[9](Conv2D).b       10 (10, 1, 1)\n",
      "(Sequential)[9](Conv2D).w      640 (1, 1, 64, 10)\n",
      "+Total(8)                    78378\n"
     ]
    }
   ],
   "source": [
    "# Create a classifier.\n",
    "# For now, we're going to start with three convolutional layers,\n",
    "# with ReLU activations after each\n",
    "\n",
    "# To compute the final predictions, we use a 1D convolution\n",
    "# converting the input to (batch_size, 10, 28, 28)\n",
    "# then we take the mean over the 28x28 image\n",
    "# (instead of a big fully-connected layer)\n",
    "\n",
    "def conv_relu_pool(in_layers, out_layers, pool=True):\n",
    "    ops = [objax.nn.Conv2D(in_layers, out_layers, 5),\n",
    "            objax.functional.relu]\n",
    "    if pool:\n",
    "        ops.append(lambda x: objax.functional.average_pool_2d(x, size=2, strides=1))\n",
    "    return ops\n",
    "\n",
    "model = objax.nn.Sequential(conv_relu_pool(1, 32) + \\\n",
    "                            conv_relu_pool(32, 32) + \\\n",
    "                            conv_relu_pool(32, 64) + \\\n",
    "                            [objax.nn.Conv2D(64, 10, 1),\n",
    "                             lambda x: x.mean((2,3))])\n",
    "\n",
    "# Print the model out to show what we'll be optimizing\n",
    "print(model.vars())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the optimizer to train the model.\n",
    "# We'll start with Adam as a good baseline optimizer.\n",
    "# We pass in the model variables to optimize\n",
    "\n",
    "opt = objax.optimizer.Adam(model.vars())\n",
    "\n",
    "# Next, we'll create our prediction function \n",
    "# We compute the softmax of the model predictions,\n",
    "# and wrap this with objax.Jit in order to run the JAX jit over everything.\n",
    "predict = objax.Jit(lambda x: objax.functional.softmax(model(x)), model.vars())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the loss function by generating model predictions, \n",
    "# and then taking the sparse softmax cross entropy predictions\n",
    "\n",
    "def loss(x, label):\n",
    "    logit = model(x)\n",
    "    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()\n",
    "\n",
    "# Next, we're going to call objax.GradValues.\n",
    "# This returns a *function* that will return a tuple:\n",
    "# grad: the gradient of the loss on the input with respect to the model parameters\n",
    "# values: the values of the loss function on each input\n",
    "\n",
    "gv = objax.GradValues(loss, model.vars())\n",
    "\n",
    "# Finally, we create our function that will run the neural network training\n",
    "\n",
    "def train_op_nojit(x, y):\n",
    "    # Get the gradients and the losses at the current input\n",
    "    g, v = gv(x, y)\n",
    "    \n",
    "    # Then, run the optimizer passing in the learning rate and the gradients\n",
    "    opt(lr=0.002, grads=g)\n",
    "    \n",
    "    # Finally, return the loss on these examples\n",
    "    return v\n",
    "\n",
    "# Again we JIT the train_op so that it runs more quickly.\n",
    "train_op = objax.Jit(train_op_nojit, model.vars() + opt.vars())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 1.0557986\n",
      "loss 0.18019806\n",
      "loss 0.11240303\n",
      "loss 0.087404795\n",
      "loss 0.07405012\n",
      "loss 0.065644935\n",
      "loss 0.059288323\n",
      "loss 0.053326696\n",
      "loss 0.05036633\n",
      "loss 0.04767788\n",
      "model accuracy 0.9756\n"
     ]
    }
   ],
   "source": [
    "# Train for just ten epochs, because MNIST is simple enough for that.\n",
    "# This should be fast if you have a GPU attached to JAX\n",
    "# otherwise it will take a few minutes on a CPU\n",
    "def train_epoch(batch_size=50):    \n",
    "    # We'll save the losses after each minibatch\n",
    "    losses = []\n",
    "    \n",
    "    # Iterate over the training data, batch_size examples at a time\n",
    "    for x_batch, y_batch in zip(x_train.reshape((-1, batch_size, 1, 28, 28)), y_train.reshape((-1, batch_size))):\n",
    "        # And run the training op on each minibatch\n",
    "        losses.append(train_op(x_batch, y_batch))\n",
    "\n",
    "    return np.mean(losses)\n",
    "\n",
    "for i in range(10):\n",
    "    print('loss', train_epoch())\n",
    "    \n",
    "def compute_accuracy():\n",
    "    # Compute the final predictions on the test data\n",
    "    test_predictions = [predict(test_batch).argmax(1) for test_batch in x_test.reshape((-1, 50, 1, 28, 28))]\n",
    "    return np.mean(y_test == np.array(test_predictions).flatten())\n",
    "\n",
    "print('model accuracy', compute_accuracy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a bigger model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Sequential)[0](Conv2D).b        32 (32, 1, 1)\n",
      "(Sequential)[0](Conv2D).w       800 (5, 5, 1, 32)\n",
      "(Sequential)[2](Conv2D).b        64 (64, 1, 1)\n",
      "(Sequential)[2](Conv2D).w     51200 (5, 5, 32, 64)\n",
      "(Sequential)[5](Conv2D).b        64 (64, 1, 1)\n",
      "(Sequential)[5](Conv2D).w    102400 (5, 5, 64, 64)\n",
      "(Sequential)[7](Conv2D).b       128 (128, 1, 1)\n",
      "(Sequential)[7](Conv2D).w    204800 (5, 5, 64, 128)\n",
      "(Sequential)[10](Conv2D).b       10 (10, 1, 1)\n",
      "(Sequential)[10](Conv2D).w     1280 (1, 1, 128, 10)\n",
      "+Total(10)                   360778\n"
     ]
    }
   ],
   "source": [
    "# Okay, now with a simple and tiny baseline model trained, let's train a\n",
    "# bigger model that we hope can do better.\n",
    "# Then, we'll try to prune it down to be just as small, but with better accuracy.\n",
    "\n",
    "model = objax.nn.Sequential(conv_relu_pool(1, 32, pool=False) + \\\n",
    "                            conv_relu_pool(32, 64) + \\\n",
    "                            conv_relu_pool(64, 64, pool=False) + \\\n",
    "                            conv_relu_pool(64, 128) + \\\n",
    "                            [objax.nn.Conv2D(128, 10, 1),\n",
    "                             lambda x: x.mean((2,3))])\n",
    "print(model.vars())\n",
    "predict = objax.Jit(lambda x: objax.functional.softmax(model(x)), model.vars())\n",
    "\n",
    "\n",
    "def loss_with_wd(x, label):\n",
    "    logit = model(x)\n",
    "    xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()\n",
    "    wd_loss = sum(jn.abs(v).sum() for k,v in model.vars().items() if k.endswith('.w'))\n",
    "    return xe_loss + wd_loss * 1e-5\n",
    "\n",
    "# Redefine the gradient and train ops\n",
    "\n",
    "opt = objax.optimizer.Adam(model.vars())\n",
    "gv = objax.GradValues(loss_with_wd, model.vars())\n",
    "train_op = objax.Jit(train_op_nojit, model.vars() + opt.vars())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.10089671\n",
      "loss 0.09586551\n",
      "loss 0.09251984\n",
      "loss 0.08294758\n",
      "loss 0.079793945\n",
      "loss 0.07640481\n",
      "loss 0.074100986\n",
      "loss 0.07067666\n",
      "loss 0.06935732\n",
      "loss 0.06614431\n",
      "model accuracy 0.9907\n",
      "Small weight ratio on layer (Sequential)[0](Conv2D).w 0.055\n",
      "Small weight ratio on layer (Sequential)[2](Conv2D).w 0.83347654\n",
      "Small weight ratio on layer (Sequential)[5](Conv2D).w 0.9338476\n",
      "Small weight ratio on layer (Sequential)[7](Conv2D).w 0.84509766\n",
      "Small weight ratio on layer (Sequential)[10](Conv2D).w 0.165625\n"
     ]
    }
   ],
   "source": [
    "# And train!\n",
    "for i in range(10):\n",
    "    random_shuffle = np.arange(x_train.shape[0])\n",
    "    np.random.shuffle(random_shuffle)\n",
    "    x_train = x_train[random_shuffle]\n",
    "    y_train = y_train[random_shuffle]\n",
    "    print('loss', train_epoch(batch_size=200))\n",
    "    \n",
    "print('model accuracy', compute_accuracy())\n",
    "for k,v in model.vars().items():\n",
    "    if k.endswith('.w'):\n",
    "        print(\"Small weight ratio on layer\", k, (jn.abs(v) < 1e-2).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss 0.089355305\n",
      "Loss 0.083341315\n",
      "Loss 0.07977875\n",
      "Loss 0.07744293\n",
      "Loss 0.075916454\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Sequential)[0](Conv2D).b        32 (32, 1, 1)\n",
      "(Sequential)[0](Conv2D).w       800 (5, 5, 1, 32)\n",
      "(Sequential)[2](Conv2D).b        64 (64, 1, 1)\n",
      "(Sequential)[2](Conv2D).w     51200 (5, 5, 32, 64)\n",
      "(Sequential)[5](Conv2D).b        64 (64, 1, 1)\n",
      "(Sequential)[5](Conv2D).w    102400 (5, 5, 64, 64)\n",
      "(Sequential)[7](Conv2D).b       128 (128, 1, 1)\n",
      "(Sequential)[7](Conv2D).w    204800 (5, 5, 64, 128)\n",
      "(Sequential)[10](Conv2D).b       10 (10, 1, 1)\n",
      "(Sequential)[10](Conv2D).w     1280 (1, 1, 128, 10)\n",
      "+Total(10)                   360778\n"
     ]
    }
   ],
   "source": [
    "print(model.vars())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
